{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook regroups the code sample of the video below, which is a part of the [Hugging Face course](https://huggingface.co/course)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/VN67ZpN33Ss?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#@title\n",
    "from IPython.display import HTML\n",
    "\n",
    "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/VN67ZpN33Ss?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers and Datasets libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install datasets transformers[sentencepiece]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_labels(offsets, answer_start, answer_end, sequence_ids):\n",
    "    idx = 0\n",
    "    while sequence_ids[idx] != 1:\n",
    "        idx += 1\n",
    "    context_start = idx\n",
    "    while sequence_ids[idx] == 1:\n",
    "        idx += 1\n",
    "    context_end = idx - 1\n",
    "\n",
    "    if offsets[context_start][0] > answer_end or offsets[context_end][1] < answer_start:\n",
    "        return (0, 0)\n",
    "    else:\n",
    "        idx = context_start\n",
    "        while idx <= context_end and offsets[idx][0] <= answer_start:\n",
    "            idx += 1\n",
    "        start_position = idx - 1\n",
    "\n",
    "        idx = context_end\n",
    "        while idx >= context_start and offsets[idx][1] >= answer_end:\n",
    "            idx -= 1\n",
    "        end_position = idx + 1\n",
    "    \n",
    "        return start_position, end_position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_validation_examples(examples):\n",
    "    questions = [q.strip() for q in examples[\"question\"]]\n",
    "    inputs = tokenizer(\n",
    "        examples[\"question\"],\n",
    "        examples[\"context\"],\n",
    "        truncation=\"only_second\",\n",
    "        padding=\"max_length\",\n",
    "        max_length=384,\n",
    "        stride=128,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_offsets_mapping=True,\n",
    "    )\n",
    "    \n",
    "    offset_mapping = inputs[\"offset_mapping\"]\n",
    "    sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n",
    "    inputs[\"start_positions\"] = []\n",
    "    inputs[\"end_positions\"] = []\n",
    "    inputs[\"example_id\"] = []\n",
    "\n",
    "    for i, offset in enumerate(offset_mapping):\n",
    "        sample_idx = sample_map[i]\n",
    "        inputs[\"example_id\"].append(examples[\"id\"][sample_idx])\n",
    "        sequence_ids = inputs.sequence_ids(i)\n",
    "        offset_mapping[i] = [(o if s == 1 else None) for o, s in zip(offset, sequence_ids)]\n",
    "        start, end = find_labels(\n",
    "            offset, examples[\"answer_start\"][sample_idx], examples[\"answer_end\"][sample_idx], sequence_ids\n",
    "        )\n",
    "        \n",
    "        inputs[\"start_positions\"].append(start)\n",
    "        inputs[\"end_positions\"].append(end)\n",
    "\n",
    "    return inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_checkpoint = \"distilbert-base-cased-distilled-squad\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
    "\n",
    "raw_datasets = load_dataset(\"squad\")\n",
    "raw_datasets = raw_datasets.remove_columns([\"title\"])\n",
    "\n",
    "def prepare_data(example):\n",
    "    answer = example[\"answers\"][\"text\"][0]\n",
    "    example[\"answer_start\"] = example[\"answers\"][\"answer_start\"][0]\n",
    "    example[\"answer_end\"] = example[\"answer_start\"] + len(answer)\n",
    "    return example\n",
    "\n",
    "validation_set = raw_datasets[\"validation\"].map(prepare_data, remove_columns=[\"answers\"])\n",
    "validation_features = validation_set.map(\n",
    "    preprocess_validation_examples,\n",
    "    batched=True,\n",
    "    remove_columns=validation_set.column_names,\n",
    ")\n",
    "len(validation_set), len(validation_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TFAutoModelForQuestionAnswering\n",
    "\n",
    "tf_validation_set = validation_features.to_tf_dataset(\n",
    "    columns=[\"input_ids\", \"attention_mask\"],\n",
    "    batch_size=16,\n",
    "    shuffle=False,\n",
    ")\n",
    "\n",
    "model = TFAutoModelForQuestionAnswering.from_pretrained(model_checkpoint)\n",
    "predictions = model.predict(tf_validation_set)\n",
    "\n",
    "start_logits = predictions.start_logits\n",
    "end_logits = predictions.end_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "example_to_feature = collections.defaultdict(list)\n",
    "for idx, feature in enumerate(validation_features):\n",
    "    example_id = feature[\"example_id\"]\n",
    "    example_to_feature[example_id].append(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score[start_pos, end_pos] = start_probabilities[start_pos] * end_probabilities[end_pos]\n",
    "logit_score[start_pos, end_pos] = start_logits[start_pos] + end_logits[end_pos]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "start_logit = start_logits[0]\n",
    "end_logit = end_logits[0]\n",
    "offsets = validation_features[0][\"offset_mapping\"]\n",
    "\n",
    "context = validation_set[0][\"context\"]\n",
    "\n",
    "start_indexes = np.argsort(start_logit)[-1 : -21 : -1].tolist()\n",
    "end_indexes = np.argsort(end_logit)[-1 : -21 : -1].tolist()\n",
    "answers = []\n",
    "for start_index in start_indexes:\n",
    "    for end_index in end_indexes:\n",
    "        # Predicting (0, 0) means no answer.\n",
    "        if start_index == 0 and end_index == 0:\n",
    "            answers.append({\"text\": \"\", \"logit_score\": start_logit[start_index] + end_logit[end_index]})\n",
    "        # Skip answers that are not fully in the context.\n",
    "        elif offsets[start_index] is None or offsets[end_index] is None:\n",
    "            continue\n",
    "        # Skip answers with a length that is either < 0 or > max_answer_length.\n",
    "        elif end_index < start_index or end_index - start_index + 1 > 30:\n",
    "            continue\n",
    "        else:\n",
    "            answers.append({\n",
    "                \"text\": context[offsets[start_index][0]: offsets[end_index][1]],\n",
    "                \"logit_score\": start_logit[start_index] + end_logit[end_index],\n",
    "            })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted_answer = max(answers, key = lambda x: x[\"logit_score\"])\n",
    "print(f\"Predicted answer: {predicted_answer}\")\n",
    "\n",
    "answer_start = validation_set[0][\"answer_start\"]\n",
    "answer_end = validation_set[0][\"answer_end\"]\n",
    "right_answer = context[answer_start: answer_end]\n",
    "print(f\"Theorerical answer: {right_answer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted_answers = {}\n",
    "for example in tqdm(validation_set):\n",
    "    example_id = example[\"id\"]\n",
    "    context = example[\"context\"]\n",
    "    answers = []\n",
    "    \n",
    "    for feature_index in example_to_feature[example_id]:\n",
    "        start_logit = start_logits[feature_index]\n",
    "        end_logit = end_logits[feature_index]\n",
    "        offsets = validation_features[feature_index][\"offset_mapping\"]\n",
    "\n",
    "        start_indexes = np.argsort(start_logit)[-1 : -11 : -1].tolist()\n",
    "        end_indexes = np.argsort(end_logit)[-1 : -11 : -1].tolist()\n",
    "        for start_index in start_indexes:\n",
    "            for end_index in end_indexes:\n",
    "                # Predicting (0, 0) means no answer.\n",
    "                if start_index == 0 and end_index == 0:\n",
    "                    answers.append({\"text\": \"\", \"logit_score\": start_logit[start_index] + end_logit[end_index]})\n",
    "                # Skip answers that are not fully in the context.\n",
    "                elif offsets[start_index] is None or offsets[end_index] is None:\n",
    "                    continue\n",
    "                # Skip answers with a length that is either < 0 or > max_answer_length.\n",
    "                elif end_index < start_index or end_index - start_index + 1 > 30:\n",
    "                    continue\n",
    "                else:\n",
    "                    answers.append({\n",
    "                        \"text\": context[offsets[start_index][0]: offsets[end_index][1]],\n",
    "                        \"logit_score\": start_logit[start_index] + end_logit[end_index],\n",
    "                    })\n",
    "\n",
    "    best_answer = max(answers, key= lambda x: x[\"logit_score\"])\n",
    "    predicted_answers[example_id] = best_answer[\"text\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "name": "The Post processing step in Question Answering (TensorFlow)",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
